package org.zjvis.datascience.common.algo;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Joiner;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zjvis.datascience.common.constant.SqlTemplate;
import org.zjvis.datascience.common.enums.AlgEnum;
import org.zjvis.datascience.common.enums.SubTypeEnum;
import org.zjvis.datascience.common.sql.SqlHelper;
import org.zjvis.datascience.common.util.ToolUtil;
import org.zjvis.datascience.common.vo.TaskVO;

import java.util.ArrayList;
import java.util.List;

/**
 * @description Operator template class for PCA dimensionality reduction.
 *              Builds the statement that runs PCA on either the MADlib or the Spark engine
 *              and describes the operator's output table for the pipeline.
 * @date 2021-12-24
 */
public class PcaDenseAlg extends BaseAlg {

    private static final Logger logger = LoggerFactory.getLogger("PcaDenseAlg");

    /** MADlib call on the full source table. */
    private static final String SQL_TPL_MADLIB = "select * from \"%s\".\"pca_dense\"('%s', '%s', '%s', '%s', '%s', '%s', '%s', %s, %s, %s)";

    /** MADlib call on a sample; the first argument is the CREATE VIEW statement defining the sample. */
    private static final String SQL_TPL_MADLIB_SAMPLE = "select * from \"%s\".\"pca_dense_sample\"('CREATE VIEW %s AS SELECT %s from %s where \"%s\" <= %s', '%s', '%s', '%s', '%s', '%s', '%s', %s, %s, %s)";

    /** Command line handed to the Spark engine. */
    private static final String SQL_TPL_SPARK = "pca -s %s -f %s -t %s -uk %d -k %s -idcol %s";

    private static final String TPL_FILENAME = "template/algo/pca_dense.json";

    /** Proportion written by {@link #defineOutput(TaskVO)} and used when the task carries none. */
    private static final double DEFAULT_PROPORTION = 0.9;

    public PcaDenseAlg() {
        super(AlgEnum.PCA_DENSE.name(), SubTypeEnum.DIMENSION_REDUCTION.getVal(),
                SubTypeEnum.DIMENSION_REDUCTION.getDesc());
        this.maxParentNumber = 1;
    }

    /**
     * Builds the PCA statement for the given tables and parameters.
     *
     * <p>When {@code sampleTable} is non-empty the sampling template is used regardless of the
     * configured engine. Otherwise the template is chosen from {@code getEngine()}.
     * All values are interpolated into the statement text as-is, so callers must pass trusted
     * identifiers.
     *
     * @param sourceTable   input table
     * @param modelTable    table receiving the PCA model
     * @param resultTable   table receiving the projected rows
     * @param residualTable table receiving the residuals
     * @param summaryTable  table receiving the result summary
     * @param featureCols   comma-separated feature column names
     * @param k             number of target dimensions
     * @param proportion    proportion passed to the MADlib templates
     * @param isK           flag passed to the MADlib templates (1 = use {@code k}, 0 = use {@code proportion},
     *                      as set by {@link #initSql})
     * @param timeStamp     value passed as {@code -uk} to the Spark template
     * @param sampleTable   name of the sample view, or empty/null for a full run
     * @return the statement, or an empty string when the sample branch has no feature columns
     *         or the engine is neither MADlib nor Spark
     */
    public String getPcaDenseSql(String sourceTable, String modelTable, String resultTable,
                                 String residualTable, String summaryTable, String featureCols, int k, float proportion,
                                 int isK, long timeStamp, String sampleTable) {
        if (StringUtils.isNotEmpty(sampleTable)) {
            // Sampling always goes through the MADlib engine.
            return buildSampleSql(sourceTable, modelTable, resultTable, residualTable,
                    summaryTable, featureCols, k, proportion, isK, sampleTable);
        }
        // Full run: pick the template from the configured engine.
        if (getEngine().isMadlib()) {
            return String.format(SQL_TPL_MADLIB, SqlTemplate.SCHEMA, sourceTable, modelTable,
                    resultTable, residualTable, summaryTable, ID_COL, featureCols, k, proportion,
                    isK);
        }
        if (getEngine().isSpark()) {
            return String.format(SQL_TPL_SPARK, sourceTable, featureCols, resultTable, timeStamp,
                    k, ID_COL);
        }
        return StringUtils.EMPTY;
    }

    /** Builds the sampling statement: a view over the id column plus the quoted feature columns. */
    private String buildSampleSql(String sourceTable, String modelTable, String resultTable,
                                  String residualTable, String summaryTable, String featureCols,
                                  int k, float proportion, int isK, String sampleTable) {
        if (StringUtils.isEmpty(featureCols)) {
            return StringUtils.EMPTY;
        }
        List<String> quotedFields = new ArrayList<>();
        quotedFields.add(String.format("\"%s\"", ID_COL));
        for (String column : featureCols.split(",")) {
            quotedFields.add(String.format("\"%s\"", column));
        }
        return String.format(SQL_TPL_MADLIB_SAMPLE, SqlTemplate.SCHEMA, sampleTable,
                Joiner.on(",").join(quotedFields),
                sourceTable, ID_COL, SAMPLE_NUMBER, modelTable, resultTable,
                residualTable, summaryTable, ID_COL, featureCols, k, proportion, isK);
    }

    /**
     * Reads the task configuration and produces the PCA statement.
     *
     * <p>Side effects: stores {@code engineName} on this instance and, when sampling is chosen,
     * sets {@code isSample} to {@code "CREATE"} in {@code json}.
     *
     * @param json       task configuration as prepared by {@link #defineOutput(TaskVO)}
     * @param sqlHelpers not used by this operator
     * @param timeStamp  passed to {@code ToolUtil.alignTableName} for every table name
     * @param engineName engine to run on
     * @return the statement, or an empty string when {@code dimension_num} is missing or no
     *         statement can be built
     */
    public String initSql(JSONObject json, List<SqlHelper> sqlHelpers, long timeStamp,
                          String engineName) {
        this.engineName = engineName;
        String sourceTable = ToolUtil.alignTableName(json.getString("source_table"), timeStamp);
        String modelTable = ToolUtil.alignTableName(json.getString("out_table_model"), timeStamp);
        String resultTable = ToolUtil.alignTableName(json.getString("out_table_rename"), timeStamp);
        String residualTable = ToolUtil.alignTableName(json.getString("residualTable"), timeStamp);
        String summaryTable = ToolUtil.alignTableName(json.getString("summaryTable"), timeStamp);

        // Constant-first equals: a missing is_dim falls back to 0 instead of throwing.
        int isK = "dimension".equals(json.getString("is_dim")) ? 1 : 0;

        // getInteger/getFloat return boxed values; guard before unboxing.
        Integer dimensionNum = json.getInteger("dimension_num");
        if (dimensionNum == null) {
            logger.warn("dimension_num is missing, cannot build pca sql");
            return StringUtils.EMPTY;
        }
        Float configuredProportion = json.getFloat("proportion");
        float proportion = configuredProportion == null
                ? (float) DEFAULT_PROPORTION : configuredProportion;

        String featureCols = this.getFeatureColsStr(json.getJSONArray("feature_cols"));

        // Sample when no sample state is recorded yet, or the previous attempt finished.
        String sampleTable = "";
        String sampleState = json.getString("isSample");
        if (!json.containsKey("isSample") || "SUCCESS".equals(sampleState)
                || "FAIL".equals(sampleState)) {
            sampleTable = resultTable.replace("solid_", "view_");
            json.put("isSample", "CREATE");
        }

        String sql = this.getPcaDenseSql(sourceTable, modelTable, resultTable, residualTable,
                summaryTable, featureCols, dimensionNum, proportion, isK, timeStamp, sampleTable);
        logger.debug("initSql sql={}", sql);
        return sql;
    }

    /**
     * Fills {@code data} with the parameter list from the template file and the validation rule
     * for the feature columns.
     *
     * @param data template object to populate
     */
    public void initTemplate(JSONObject data) {
        JSONArray params = getTemplateParamList(TPL_FILENAME);
        data.put("setParams", params);
        baseInitTemplate(data);
        JSONArray validate = new JSONArray();
        validate.add("feature_cols,number");
        data.put("validate", validate);
    }

    /**
     * Derives the output table names and the result-table schema from the task data and writes
     * them back into the task.
     *
     * <p>The result table has the id column followed by {@code f1..fk}, where {@code k} is
     * {@code dimension_num}. Returns early (leaving {@code output} unset) when there is no
     * input or {@code dimension_num} is not available.
     *
     * @param vo task whose data is read and updated
     */
    public void defineOutput(TaskVO vo) {
        JSONObject jsonObject = vo.getData();
        String outTablePrefix = jsonObject.getString("out_table");
        String tableName = outTableName(outTablePrefix, vo);
        jsonObject.put("out_table_rename", tableName);

        JSONArray input = jsonObject.getJSONArray("input");
        if (input == null || input.isEmpty()) {
            logger.warn("input is empty");
            return;
        }
        this.checkBoxSelectFilter(jsonObject, "number", FEATURE_COLS);
        this.supplementForCheckbox(jsonObject, TPL_FILENAME, 1, vo);
        // NOTE(review): this guard only fires when BOTH keys are missing, while the message
        // mentions feature_cols only; confirm whether "||" was intended.
        if (!jsonObject.containsKey("feature_cols") && !jsonObject.containsKey("dimension_num")) {
            logger.warn("feature_cols not exits");
            vo.setData(jsonObject);
            return;
        }

        jsonObject.put("source_table", input.getJSONObject(0).getString("tableName"));

        // Covers both an absent key and a key holding null, avoiding an unboxing NPE.
        Integer k = jsonObject.getInteger("dimension_num");
        if (k == null) {
            vo.setData(jsonObject);
            return;
        }

        JSONArray resultCols = new JSONArray();
        JSONArray resultColTypes = new JSONArray();
        resultCols.add(ID_COL);
        resultColTypes.add("integer");
        for (int i = 1; i <= k; ++i) {
            resultCols.add(String.format("f%s", i));
            resultColTypes.add("numeric");
        }

        JSONObject resultItem = new JSONObject();
        resultItem.put("tableName", tableName);
        resultItem.put("nodeName",
                vo.getName() == null ? AlgEnum.PCA_DENSE.toString() : vo.getName());
        resultItem.put("tableCols", resultCols);
        resultItem.put("columnTypes", resultColTypes);

        // The model table name is recorded for initSql; the model table itself is not
        // listed in "output".
        jsonObject.put("out_table_model", outTableName(outTablePrefix + "_model", vo));

        this.setSubTypeForOutput(resultItem);
        JSONArray output = new JSONArray();
        output.add(resultItem);
        jsonObject.put("output", output);

        jsonObject.put("residualTable", outTableName(outTablePrefix + "_residual", vo));
        jsonObject.put("summaryTable", outTableName(outTablePrefix + "_resultsummary", vo));
        // TODO Support proportion way
        jsonObject.put("is_dim", "dimension");
        jsonObject.put("proportion", DEFAULT_PROPORTION);
        vo.setData(jsonObject);
    }

    /** Formats an output table name from a prefix and the task's pipeline id and task id. */
    private static String outTableName(String prefix, TaskVO vo) {
        return String.format(SqlTemplate.OUT_TABLE_NAME, prefix, vo.getPipelineId(), vo.getId());
    }
}
